26 abc抽象基类
你写了一个基类,希望所有子类都实现某个方法,但Python不会强制要求——子类可以不实现,运行时才会报错。abc模块就是解决这个问题的,它让你定义"抽象方法",子类必须实现这些方法才能实例化。
一、基本用法
1.1 定义抽象基类
python
from abc import ABC, abstractmethod
class Shape(ABC):
@abstractmethod
def area(self):
"""计算面积"""
pass
@abstractmethod
def perimeter(self):
"""计算周长"""
pass
# 不能实例化抽象类
# shape = Shape() # TypeError: Can't instantiate abstract class1.2 实现子类
python
import math
class Circle(Shape):
def __init__(self, radius):
self.radius = radius
def area(self):
return math.pi * self.radius ** 2
def perimeter(self):
return 2 * math.pi * self.radius
class Rectangle(Shape):
def __init__(self, width, height):
self.width = width
self.height = height
def area(self):
return self.width * self.height
def perimeter(self):
return 2 * (self.width + self.height)
# 现在可以实例化
circle = Circle(5)
print(circle.area()) # 78.53981633974483
print(circle.perimeter()) # 31.41592653589793
rect = Rectangle(4, 6)
print(rect.area()) # 241.3 子类必须实现所有抽象方法
python
class IncompleteShape(Shape):
def area(self):
return 0
# 没有实现perimeter
# 不能实例化
# shape = IncompleteShape() # TypeError二、抽象方法的变体
2.1 抽象方法可以有默认实现
python
from abc import ABC, abstractmethod
class Base(ABC):
@abstractmethod
def required_method(self):
pass
@abstractmethod
def optional_method(self):
# 有默认实现,子类可以选择调用
return "默认实现"
class Child(Base):
def required_method(self):
return "必须实现"
def optional_method(self):
# 调用父类的默认实现
return super().optional_method() + " + 子类扩展"
child = Child()
print(child.optional_method()) # 默认实现 + 子类扩展2.2 类方法和静态方法
python
from abc import ABC, abstractmethod
class Factory(ABC):
@classmethod
@abstractmethod
def create(cls):
pass
class ConcreteFactory(Factory):
@classmethod
def create(cls):
return cls()三、注册虚拟子类
3.1 register()
不继承也能被认为是子类。
python
from abc import ABC, abstractmethod
class Shape(ABC):
@abstractmethod
def area(self):
pass
# 注册为虚拟子类
@Shape.register
class MyShape:
def area(self):
return 42
# MyShape现在被认为是Shape的子类
print(isinstance(MyShape(), Shape)) # True
print(issubclass(MyShape, Shape)) # True3.2 subclasshook
更灵活的子类检查。
python
from abc import ABC
class SupportsArea(ABC):
@classmethod
def __subclasshook__(cls, C):
if cls is SupportsArea:
if any("area" in B.__dict__ for B in C.__mro__):
return True
return NotImplemented
# 任何有area方法的类都被认为是子类
class MyShape:
def area(self):
return 42
print(isinstance(MyShape(), SupportsArea)) # True四、实际应用场景
4.1 插件系统
python
from abc import ABC, abstractmethod
class Plugin(ABC):
@abstractmethod
def name(self):
pass
@abstractmethod
def execute(self, data):
pass
class TextPlugin(Plugin):
def name(self):
return "text"
def execute(self, data):
return data.upper()
class JsonPlugin(Plugin):
def name(self):
return "json"
def execute(self, data):
import json
return json.loads(data)
# 插件注册表
plugins = {}
def register_plugin(plugin_cls):
p = plugin_cls()
plugins[p.name()] = p
register_plugin(TextPlugin)
register_plugin(JsonPlugin)4.2 策略模式
python
from abc import ABC, abstractmethod
class SortStrategy(ABC):
@abstractmethod
def sort(self, data):
pass
class BubbleSort(SortStrategy):
def sort(self, data):
# 冒泡排序实现
return sorted(data)
class QuickSort(SortStrategy):
def sort(self, data):
# 快速排序实现
return sorted(data)
class Sorter:
def __init__(self, strategy: SortStrategy):
self.strategy = strategy
def sort(self, data):
return self.strategy.sort(data)
# 使用
sorter = Sorter(BubbleSort())
print(sorter.sort([3, 1, 4, 1, 5]))4.3 数据处理器
python
from abc import ABC, abstractmethod
class DataProcessor(ABC):
@abstractmethod
def load(self, path):
pass
@abstractmethod
def process(self, data):
pass
@abstractmethod
def save(self, data, path):
pass
def run(self, input_path, output_path):
data = self.load(input_path)
processed = self.process(data)
self.save(processed, output_path)
class CsvProcessor(DataProcessor):
def load(self, path):
import csv
with open(path) as f:
return list(csv.reader(f))
def process(self, data):
return [row for row in data if any(row)]
def save(self, data, path):
import csv
with open(path, 'w', newline='') as f:
csv.writer(f).writerows(data)4.4 接口定义
python
from abc import ABC, abstractmethod
from typing import List
class Repository(ABC):
@abstractmethod
def get(self, id):
pass
@abstractmethod
def save(self, entity):
pass
@abstractmethod
def delete(self, id):
pass
@abstractmethod
def list_all(self) -> List:
pass
class UserRepository(Repository):
def __init__(self):
self.users = {}
def get(self, id):
return self.users.get(id)
def save(self, entity):
self.users[entity['id']] = entity
def delete(self, id):
self.users.pop(id, None)
def list_all(self):
return list(self.users.values())五、与Protocol对比
Python 3.8+引入了Protocol,也能定义接口:
python
# abc方式:显式继承
from abc import ABC, abstractmethod
class Shape(ABC):
@abstractmethod
def area(self):
pass
class Circle(Shape): # 必须继承
def area(self):
return 3.14
# Protocol方式:结构化子类型
from typing import Protocol
class ShapeProtocol(Protocol):
def area(self) -> float:
pass
class Circle:
def area(self): # 不需要继承
return 3.14
def print_area(shape: ShapeProtocol):
print(shape.area())
print_area(Circle()) # OK,Circle有area方法| 特性 | abc | Protocol |
|---|---|---|
| 检查方式 | 继承检查 | 结构化检查 |
| 运行时强制 | 是 | 否 |
| 显式继承 | 需要 | 不需要 |
| 灵活性 | 低 | 高 |
六、常见错误
6.1 忘记实现抽象方法
python
from abc import ABC, abstractmethod
class Base(ABC):
@abstractmethod
def method(self):
pass
class Child(Base):
pass
# Child() # TypeError: Can't instantiate abstract class6.2 抽象方法没有body
python
# 错误
class Base(ABC):
@abstractmethod
def method(): # 缺少self
pass
# 正确
class Base(ABC):
@abstractmethod
def method(self):
pass七、总结
abc模块的核心:
| 组件 | 用途 |
|---|---|
ABC | 抽象基类辅助类 |
@abstractmethod | 抽象方法装饰器 |
register() | 注册虚拟子类 |
__subclasshook__ | 自定义子类检查 |
使用场景:
- 定义接口规范
- 强制子类实现特定方法
- 插件系统
- 策略模式
记住:ABC + @abstractmethod就够了。想让子类必须实现某个方法,就把它标记为抽象方法。